# Standard Library Imports
from collections import defaultdict

# Third-Party Imports
from tqdm import tqdm
from PIL import Image

# Project-Specific Imports
import CLIPAD
from modules import *
from utils import *


def get_attn_map(timesteps, blocks, tokenizer, prompt, word, height, width):
    """Build one spatial attention map per timestep for ``word`` inside ``prompt``.

    The token ids of ``word`` (tokenized without special tokens) are located
    as a contiguous run inside the token ids of ``prompt``. For every
    timestep, all entries of the module-level ``attn_maps[timestep]`` whose
    name contains one of ``blocks`` are averaged over their first axis,
    resized to ``(height, width)`` and summed; the slices belonging to the
    word's token positions are then averaged into a single map.

    Args:
        timesteps: Keys to look up in the module-level ``attn_maps`` store.
        blocks: Substrings; a stored layer name is used if it contains any
            of them.
        tokenizer: Callable returning an object with ``input_ids``; it must
            accept ``add_special_tokens=False``.
        prompt: Full text prompt.
        word: Text whose tokens are searched for inside ``prompt``.
        height: Output map height passed to ``F.interpolate``.
        width: Output map width passed to ``F.interpolate``.

    Returns:
        dict mapping each timestep to its aggregated attention map.

    Raises:
        ValueError: If ``word`` produces no tokens, if its tokens do not
            occur contiguously in the tokenized prompt, or if no stored
            layer name matches ``blocks`` for some timestep.
    """
    prompt_ids = tokenizer(prompt).input_ids
    word_ids = tokenizer(word, add_special_tokens=False).input_ids

    span = len(word_ids)
    if span == 0:
        # An empty index list would silently produce a NaN map further down.
        raise ValueError(f"word {word!r} produced no tokens")

    # First position at which the word's tokens appear as a contiguous run.
    start = next(
        (
            i
            for i in range(len(prompt_ids) - span + 1)
            if prompt_ids[i:i + span] == word_ids
        ),
        None,
    )
    if start is None:
        raise ValueError(
            f"tokens of word {word!r} were not found in prompt {prompt!r}"
        )
    indices = list(range(start, start + span))

    maps = {}

    for timestep in timesteps:
        attn_map = None

        for name, value in attn_maps[timestep].items():
            if not any(block in name for block in blocks):
                continue

            # Average over the first axis and drop a leading singleton axis;
            # presumably this leaves (tokens, h, w) -- TODO confirm against
            # the code that fills ``attn_maps``.
            value = torch.mean(value, dim=0).squeeze(0)
            value = F.interpolate(
                value.to(dtype=torch.float32).unsqueeze(0),
                size=(height, width),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0)

            # Accumulate a plain sum over all matching layers.
            attn_map = value if attn_map is None else attn_map + value

        if attn_map is None:
            raise ValueError(
                f"no attention map at timestep {timestep!r} matches blocks {blocks!r}"
            )

        # Keep only the word's token positions and average them.
        maps[timestep] = attn_map[indices, :, :].mean(dim=0)
    return maps
    

def get_feature_maps(timesteps, blocks):
    """Collect entry ``1`` of the stored feature map for each timestep/block.

    Reads the module-level ``feature_maps`` store as
    ``feature_maps[timestep][block]`` and keeps only index ``1`` of each
    value.

    Args:
        timesteps: Outer keys to read from ``feature_maps``.
        blocks: Inner keys to read for every timestep.

    Returns:
        dict of ``{timestep: {block: feature_map}}``. A timestep only appears
        once at least one block has been stored for it, so an empty
        ``blocks`` yields an empty dict.
    """
    selected = {}

    for step in timesteps:
        for block_name in blocks:
            # NOTE(review): index 1 presumably selects one element of a
            # batched output (e.g. the text-conditioned half) -- confirm
            # against the code that fills ``feature_maps``.
            picked = feature_maps[step][block_name][1]
            selected.setdefault(step, {})[block_name] = picked
    return selected


def build_clip_memory_bank(
    dataloader,
    model,
    device,
):
    """Build two galleries of normalised CLIP feature maps from reference images.

    Every item yielded by ``dataloader`` is indexed with ``[0]`` to obtain an
    image path. The image is loaded as RGB, passed through
    ``transform_image(image, 240, 240)`` and encoded with
    ``model.encode_image``; the third and fourth outputs are L2-normalised
    along the last axis, squeezed, and collected.

    Args:
        dataloader: Iterable whose items hold an image path at index 0.
        model: Object whose ``encode_image`` returns four values.
        device: Device the image tensor is moved to.

    Returns:
        Tuple ``(feature_gallery1, feature_gallery2)``, each the per-image
        features concatenated along dim 0.
    """

    def _unit(t):
        # L2-normalise along the last axis.
        return t / t.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        gallery_a, gallery_b = [], []

        for batch in tqdm(dataloader):
            path = batch[0]
            pil_image = Image.open(path).convert("RGB")
            # 240, 240 are presumably the target size -- TODO confirm in
            # transform_image.
            pixels = transform_image(pil_image, 240, 240)
            pixels = pixels.unsqueeze(0).to(device)

            _, _, fmap_a, fmap_b = model.encode_image(pixels)

            gallery_a.append(_unit(fmap_a).squeeze())
            gallery_b.append(_unit(fmap_b).squeeze())

        return (torch.cat(gallery_a, dim=0), torch.cat(gallery_b, dim=0))


def build_diff_memory_bank(
    dataloader,
    pipe,
    mask_image,
    object,
    template,
    timesteps,
    blocks
):
    """Run the pipeline on each reference image and collect its feature maps.

    For every item of ``dataloader`` the image at ``item[0]`` is opened as
    RGB, resized to 512x512 and passed to ``pipe`` together with the prompt
    ``template.format(object)``. The pipeline's return value is discarded;
    features are read afterwards through ``get_feature_maps``, so the call is
    presumably relied on to populate that store -- TODO confirm.

    Args:
        dataloader: Iterable whose items hold an image path at index 0.
        pipe: Callable pipeline accepting ``prompt``, ``image``,
            ``mask_image``, ``height``, ``width`` and ``ts``.
        mask_image: Forwarded unchanged to ``pipe``.
        object: Value substituted into ``template``.
        template: Format string with one positional placeholder.
        timesteps: Forwarded to ``pipe`` as ``ts`` and to ``get_feature_maps``.
        blocks: Forwarded to ``get_feature_maps``.

    Returns:
        list with one ``get_feature_maps`` result per dataloader item.
    """
    with torch.no_grad():
        bank = []

        for batch in tqdm(dataloader):
            reference = Image.open(batch[0]).convert("RGB").resize((512, 512))
            prompt = template.format(f"{object}")

            pipe(
                prompt=prompt,
                image=reference,
                mask_image=mask_image,
                height=512,
                width=512,
                ts=timesteps,
            )

            bank.append(get_feature_maps(timesteps=timesteps, blocks=blocks))
        return bank


def get_clip_language_score(
    image_path,
    model,
    object,
    template,
    device,
):
    """Return the softmax weight of the "damaged" prompt for one image.

    Two prompts are built from ``template``: ``"perfect {object}"`` and
    ``"damaged {object}"``. Each is tokenized with ``CLIPAD.tokenize``,
    encoded with ``model.encode_text``, L2-normalised, averaged over dim 0
    and normalised again. The image is loaded as RGB, passed through
    ``transform_image(image, 240, 240)`` and encoded; the first output of
    ``model.encode_image`` is normalised and compared against both text
    embeddings.

    Args:
        image_path: Path of the image to score.
        model: Object providing ``encode_text`` and ``encode_image``.
        object: Value substituted into the prompts.
        template: Format string with one positional placeholder.
        device: Device tokens and image are moved to.

    Returns:
        Element ``[0, 1]`` of the softmax over the two scaled similarities,
        i.e. the entry belonging to the "damaged" prompt.
    """

    def _prompt_embedding(prompt):
        # Encode, normalise rows, average rows, renormalise.
        embedding = model.encode_text(CLIPAD.tokenize(prompt).to(device))
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        embedding = torch.mean(embedding, dim=0, keepdim=True)
        return embedding / embedding.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        prompts = (
            template.format(f"perfect {object}"),
            template.format(f"damaged {object}"),
        )
        # Row 0 is the "perfect" prompt, row 1 the "damaged" prompt.
        text_features = torch.cat(
            [_prompt_embedding(prompt) for prompt in prompts], dim=0
        )

        pixels = transform_image(Image.open(image_path).convert("RGB"), 240, 240)
        pixels = pixels.unsqueeze(0).to(device)

        image_feature = model.encode_image(pixels)[0]
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

        # Similarities are scaled by the fixed factor 15 before the softmax.
        logits = (15 * image_feature) @ text_features.T
        return logits.softmax(dim=-1)[0, 1]
    

def get_clip_language_score_map(
    image_path,
    model,
    object,
    template,
    device,
):
    """Return a 512x512 map of per-token "damaged" prompt probabilities.

    The prompts ``"perfect {object}"`` and ``"damaged {object}"`` (wrapped in
    ``template``) are encoded and L2-normalised. The second output of
    ``model.encode_image`` is normalised per token and compared with both
    text embeddings; the softmax entry of the "damaged" prompt is reshaped
    to ``model.visual.grid_size``, upsampled bilinearly to 512x512 and
    passed through ``apply_gaussian_blur``.

    Args:
        image_path: Path of the image to score.
        model: Object providing ``encode_text``, ``encode_image`` and
            ``visual.grid_size``.
        object: Value substituted into the prompts.
        template: Format string with one positional placeholder.
        device: Device tokens and image are moved to.

    Returns:
        The blurred score map.
    """

    def _prompt_embedding(prompt):
        # Encode one prompt and L2-normalise along the last axis.
        embedding = model.encode_text(CLIPAD.tokenize(prompt).to(device))
        return embedding / embedding.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        prompts = (
            template.format(f"perfect {object}"),
            template.format(f"damaged {object}"),
        )
        # Row 0 is the "perfect" prompt, row 1 the "damaged" prompt.
        text_features = torch.cat(
            [_prompt_embedding(prompt) for prompt in prompts], dim=0
        )

        pixels = transform_image(Image.open(image_path).convert("RGB"), 240, 240)
        pixels = pixels.unsqueeze(0).to(device)

        _, token_features, _, _ = model.encode_image(pixels)
        token_features = token_features.squeeze(0)
        token_features = token_features / token_features.norm(dim=-1, keepdim=True)

        # Column 1 of the softmax is the "damaged" prompt.
        similarities = token_features @ text_features.T
        damaged_probs = similarities.softmax(dim=-1)[:, 1]

        # Lay tokens out on the visual grid and add two leading axes for
        # F.interpolate.
        grid = damaged_probs.reshape(model.visual.grid_size)[None, None]
        upsampled = F.interpolate(
            grid, size=(512, 512), mode='bilinear', align_corners=False
        )

        return apply_gaussian_blur(upsampled.squeeze())


def get_clip_vision_score_map(
    image_path,
    model,
    memory_bank,
    device,
):
    """Return a 512x512 map of distances to the closest memory-bank feature.

    The image is loaded as RGB, passed through
    ``transform_image(image, 240, 240)`` and encoded. The third and fourth
    outputs of ``model.encode_image`` are L2-normalised and compared with
    ``memory_bank[0]`` and ``memory_bank[1]`` respectively: for each query
    feature the smallest value of ``(1 - similarity) / 2`` over the bank is
    kept. The two results are averaged, reshaped to
    ``model.visual.grid_size``, upsampled bilinearly to 512x512 and passed
    through ``apply_gaussian_blur``.

    Args:
        image_path: Path of the image to score.
        model: Object providing ``encode_image`` and ``visual.grid_size``.
        memory_bank: Indexable pair of reference feature galleries.
        device: Device the image tensor is moved to.

    Returns:
        The blurred score map.
    """

    def _nearest_distance(features, gallery):
        # Normalise the query, then take the minimum of (1 - sim) over the
        # gallery and halve it.
        features = features / features.norm(dim=-1, keepdim=True)
        return (1.0 - features @ gallery.t()).min(dim=-1).values / 2.0

    with torch.no_grad():
        pixels = transform_image(Image.open(image_path).convert("RGB"), 240, 240)
        pixels = pixels.unsqueeze(0).to(device)

        _, _, fmap_a, fmap_b = model.encode_image(pixels)

        dist_a = _nearest_distance(fmap_a, memory_bank[0])
        dist_b = _nearest_distance(fmap_b, memory_bank[1])

        combined = 0.5 * (dist_a + dist_b)
        grid = combined.reshape(model.visual.grid_size)[None, None]
        upsampled = F.interpolate(
            grid, size=(512, 512), mode='bilinear', align_corners=False
        )
        return apply_gaussian_blur(upsampled.squeeze())


def get_diff_language_score_map(
    pipe,
    image_path,
    mask_image,
    object,
    states,
    template,
    timesteps,
    blocks,
):
    """Average word-level attention maps over several anomaly states.

    For each entry of ``states`` the prompt ``"{object} with {state}"``
    (wrapped in ``template``) is run through ``pipe`` on the 512x512 RGB
    image. ``get_attn_map`` is then asked for the 512x512 attention maps of
    the state word; these are averaged over ``timesteps``. The per-state
    maps are averaged again and passed through ``apply_gaussian_blur``.

    Args:
        pipe: Callable pipeline exposing ``tokenizer`` and accepting
            ``prompt``, ``image``, ``mask_image``, ``height``, ``width``
            and ``ts``.
        image_path: Path of the image to score.
        mask_image: Forwarded unchanged to ``pipe``.
        object: Value substituted into the prompts.
        states: Iterable of state words, each used both in the prompt and as
            the word whose attention is extracted.
        template: Format string with one positional placeholder.
        timesteps: Forwarded to ``pipe`` as ``ts`` and to ``get_attn_map``.
        blocks: Forwarded to ``get_attn_map``.

    Returns:
        The blurred mean attention map.
    """
    with torch.no_grad():
        query = Image.open(image_path).convert("RGB").resize((512, 512))
        per_state = []

        for state in states:
            prompt = template.format(f"{object} with {state}")

            # The return value is discarded; the call is presumably relied
            # on to populate the store read by get_attn_map -- TODO confirm.
            pipe(
                prompt=prompt,
                image=query,
                mask_image=mask_image,
                height=512,
                width=512,
                ts=timesteps
            )

            word_maps = get_attn_map(
                timesteps=timesteps,
                blocks=blocks,
                tokenizer=pipe.tokenizer,
                prompt=prompt,
                word=state,
                height=512,
                width=512
            )

            stacked = torch.stack([word_maps[timestep] for timestep in timesteps])
            per_state.append(torch.mean(stacked, dim=0))

        mean_map = torch.mean(torch.stack(per_state), dim=0)
        return apply_gaussian_blur(mean_map)


def get_diff_vision_score_map(
    pipe,
    image_path,
    mask_image,
    object,
    template,
    timesteps,
    blocks,
    memory_bank,
):
    """Score an image by its feature distance to a reference memory bank.

    The 512x512 RGB image is run through ``pipe`` with the prompt
    ``template.format(object)`` and its features are read with
    ``get_feature_maps``. For each ``(timestep, block)`` pair, every query
    location is compared with all reference locations in ``memory_bank``
    and the minimum of ``0.5 * (1 - similarity)`` is kept. Each result is
    reshaped to a square grid, upsampled bilinearly to 512x512, and the
    maps are averaged and passed through ``apply_gaussian_blur``.

    Args:
        pipe: Callable pipeline accepting ``prompt``, ``image``,
            ``mask_image``, ``height``, ``width`` and ``ts``.
        image_path: Path of the image to score.
        mask_image: Forwarded unchanged to ``pipe``.
        object: Value substituted into ``template``.
        template: Format string with one positional placeholder.
        timesteps: Forwarded to ``pipe`` as ``ts`` and to ``get_feature_maps``.
        blocks: Forwarded to ``get_feature_maps``.
        memory_bank: Iterable of ``{timestep: {block: feature_map}}`` dicts.

    Returns:
        The blurred mean score map.
    """
    with torch.no_grad():
        query = Image.open(image_path).convert("RGB").resize((512, 512))

        # The return value is discarded; the call is presumably relied on
        # to populate the store read by get_feature_maps -- TODO confirm.
        pipe(
            prompt=template.format(f"{object}"),
            image=query,
            mask_image=mask_image,
            height=512,
            width=512,
            ts=timesteps,
        )

        query_maps = get_feature_maps(
            timesteps=timesteps,
            blocks=blocks,
        )

        per_pair = []

        # NOTE(review): timesteps and blocks are paired positionally (and
        # truncated to the shorter list), although get_feature_maps collects
        # every combination -- confirm this pairing is intended.
        for timestep, block in zip(timesteps, blocks):
            # Flatten everything after the first axis and transpose so each
            # row is one location, then normalise every row.
            q = query_maps[timestep][block].flatten(1).T
            q = F.normalize(q, p=2, dim=1)

            # Reference features stay column-wise: normalise along dim 0 and
            # join all reference images along dim 1.
            refs = torch.cat(
                [
                    F.normalize(entry[timestep][block].flatten(1), p=2, dim=0)
                    for entry in memory_bank
                ],
                dim=1,
            )

            distances = 0.5 * (1 - torch.matmul(q, refs))
            nearest = distances.min(dim=1).values

            # Assumes the number of locations is a perfect square.
            side = int(math.sqrt(nearest.shape[0]))
            grid = nearest.reshape((side, side))[None, None]
            upsampled = F.interpolate(
                grid, size=(512, 512), mode='bilinear', align_corners=False
            )
            per_pair.append(upsampled.squeeze())

        mean_map = torch.mean(torch.stack(per_pair), dim=0)
        return apply_gaussian_blur(mean_map)
    